/*
 * user.c
 *
 * Copyright (C) 2013 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <user.h>
#include <process.h>
#include <heap.h>
#include <timer.h>

/*
 * Looks up a user object by UID by walking every OBJECT_USER object.
 *
 * Returns the matching user with a reference held (callers release it with
 * dereference(&user->header)), or NULL when no user has that UID.
 */
static user_t *reference_user_by_id(uid_t uid)
{
    user_t *candidate = NULL;

    while (enum_objects_by_type(OBJECT_USER, (object_t**)&candidate) == ERR_SUCCESS)
    {
        if (candidate->uid != uid) continue;

        /* Stop enumerating here; the reference taken by the enumerator goes to the caller. */
        return candidate;
    }

    return NULL;
}

/*
 * Creates and registers a new user object.
 *
 * Reads sizeof(user->password_hash) bytes from password_hash and copies them
 * into the new object; name is handed to init_object().
 *
 * Returns ERR_EXISTS if a user with the same UID or name is already
 * registered, ERR_NOMEMORY if the object cannot be allocated, otherwise the
 * result of create_object().
 */
static dword_t add_user(uid_t uid, const char *name, dword_t *password_hash, qword_t privileges)
{
    user_t *user;

    if ((user = reference_user_by_id(uid)) != NULL)
    {
        dereference(&user->header);
        return ERR_EXISTS;
    }

    /* NOTE(review): assumes reference_by_name() returns nonzero when it found (and referenced) an object. */
    if (reference_by_name(name, OBJECT_USER, (object_t**)&user))
    {
        dereference(&user->header);
        return ERR_EXISTS;
    }

    user = (user_t*)malloc(sizeof(user_t));
    if (user == NULL) return ERR_NOMEMORY;

    /*
     * Start from a zeroed object. malloc() is not assumed to return cleared
     * memory, and fields not set below would otherwise hold garbage. In
     * particular, last_login_attempt is read by syscall_logon_user() for
     * throttling, and zero means "no attempt yet".
     */
    memset(user, 0, sizeof(user_t));

    init_object(&user->header, name, OBJECT_USER);
    user->uid = uid;
    memcpy(user->password_hash, password_hash, sizeof(user->password_hash));
    user->privileges = privileges;

    dword_t ret = create_object(&user->header);
    if (ret != ERR_SUCCESS)
    {
        /* init_object() keeps its own copy of the name, which must be released too. */
        free(user->header.name);
        free(user);
    }

    return ret;
}

/*
 * Computes the SHA-256 digest (FIPS 180-4) of the first `size` bytes of
 * `buffer` and stores it in sum[0..7] as eight 32-bit words. sum[0] holds the
 * most significant word of the digest.
 *
 * Padding is generated on the fly rather than copied into a scratch buffer.
 * A 0x80 marker follows the data, then zero bytes. The last 8 bytes of the
 * final chunk hold the message length in bits, big-endian.
 */
static void sha256_compute(byte_t *buffer, size_t size, dword_t *sum)
{
    static const dword_t round_constants[] =
    {
        0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
        0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
        0xE49B69C1, 0xEFBE4786, 0x0FC19DC6, 0x240CA1CC, 0x2DE92C6F, 0x4A7484AA, 0x5CB0A9DC, 0x76F988DA,
        0x983E5152, 0xA831C66D, 0xB00327C8, 0xBF597FC7, 0xC6E00BF3, 0xD5A79147, 0x06CA6351, 0x14292967,
        0x27B70A85, 0x2E1B2138, 0x4D2C6DFC, 0x53380D13, 0x650A7354, 0x766A0ABB, 0x81C2C92E, 0x92722C85,
        0xA2BFE8A1, 0xA81A664B, 0xC24B8B70, 0xC76C51A3, 0xD192E819, 0xD6990624, 0xF40E3585, 0x106AA070,
        0x19A4C116, 0x1E376C08, 0x2748774C, 0x34B0BCB5, 0x391C0CB3, 0x4ED8AA4A, 0x5B9CCA4F, 0x682E6FF3,
        0x748F82EE, 0x78A5636F, 0x84C87814, 0x8CC70208, 0x90BEFFFA, 0xA4506CEB, 0xBEF9A3F7, 0xC67178F2
    };

    size_t i, j;
    dword_t message[64];

    /*
     * Count all full 64-byte chunks, plus one chunk for the tail. The tail
     * needs a second extra chunk when fewer than 9 bytes remain for the 0x80
     * marker and the 8-byte length. The conditional must be parenthesized
     * because '+' binds tighter than '?:'.
     */
    size_t num_chunks = (size >> 6) + (((size & 0x3F) > 55) ? 2 : 1);

    /* Held in 64 bits so shifting out all 8 length bytes is defined even with a 32-bit size_t. */
    qword_t bit_length = (qword_t)size << 3;

    sum[0] = 0x6A09E667;
    sum[1] = 0xBB67AE85;
    sum[2] = 0x3C6EF372;
    sum[3] = 0xA54FF53A;
    sum[4] = 0x510E527F;
    sum[5] = 0x9B05688C;
    sum[6] = 0x1F83D9AB;
    sum[7] = 0x5BE0CD19;

    for (i = 0; i < num_chunks; i++)
    {
        /* Build the 16 big-endian message words of this chunk, including padding bytes. */
        for (j = 0; j < 64; j++)
        {
            byte_t value = 0;

            if ((i << 6) + j < size) value = buffer[(i << 6) + j];
            else if ((i << 6) + j == size) value = 0x80;
            else if (i == num_chunks - 1 && j >= 56) value = (byte_t)((bit_length >> ((63 - j) << 3)) & 0xFF);

            /* Widen before shifting: a promoted int shifted left by 24 overflows for bytes >= 0x80. */
            switch (j & 3)
            {
            case 0:
                message[j >> 2] = (dword_t)value << 24;
                break;
            case 1:
                message[j >> 2] |= (dword_t)value << 16;
                break;
            case 2:
                message[j >> 2] |= (dword_t)value << 8;
                break;
            case 3:
                message[j >> 2] |= (dword_t)value;
                break;
            }
        }

        /* Message schedule expansion: words 16..63 derived from sigma0/sigma1 of earlier words. */
        for (j = 16; j < 64; j++)
        {
            message[j] = message[j - 7] + message[j - 16]
                         + (((message[j - 15] >> 7) | (message[j - 15] << 25))
                         ^ ((message[j - 15] >> 18) | (message[j - 15] << 14))
                         ^ (message[j - 15] >> 3))
                         + (((message[j - 2] >> 17) | (message[j - 2] << 15))
                         ^ ((message[j - 2] >> 19) | (message[j - 2] << 13))
                         ^ (message[j - 2] >> 10));
        }

        dword_t vars[8];
        for (j = 0; j < 8; j++) vars[j] = sum[j];

        /* 64 compression rounds; vars[0..7] correspond to the working variables a..h. */
        for (j = 0; j < 64; j++)
        {
            dword_t temp1 = vars[7] + (((vars[4] >> 6) | (vars[4] << 26))
                            ^ ((vars[4] >> 11) | (vars[4] << 21))
                            ^ ((vars[4] >> 25) | (vars[4] << 7)))
                            + ((vars[4] & vars[5]) ^ (~vars[4] & vars[6]))
                            + round_constants[j] + message[j];

            dword_t temp2 = (((vars[0] >> 2) | (vars[0] << 30))
                            ^ ((vars[0] >> 13) | (vars[0] << 19))
                            ^ ((vars[0] >> 22) | (vars[0] << 10)))
                            + ((vars[0] & vars[1]) ^ (vars[0] & vars[2]) ^ (vars[1] & vars[2]));

            vars[7] = vars[6];
            vars[6] = vars[5];
            vars[5] = vars[4];
            vars[4] = vars[3] + temp1;
            vars[3] = vars[2];
            vars[2] = vars[1];
            vars[1] = vars[0];
            vars[0] = temp1 + temp2;
        }

        for (j = 0; j < 8; j++) sum[j] += vars[j];
    }
}

/*
 * Returns nonzero when the calling process's current user holds every
 * privilege bit set in privilege_mask.
 */
bool_t check_privileges(qword_t privilege_mask)
{
    const user_t *user = get_current_process()->current_user;

    /* No requested bit may be missing from the user's privilege set. */
    return (privilege_mask & ~user->privileges) == 0;
}

dword_t get_current_uid(void)
{
    process_t *proc = get_current_process();
    return proc && proc->current_user ? proc->current_user->uid : 0;
}

/*
 * Switches the calling process's current user to the user with the given UID.
 *
 * Returns ERR_NOTFOUND if no such user exists. Returns ERR_FORBIDDEN if the
 * call came from user mode without PRIVILEGE_CHANGE_UID. On success the
 * reference held on the previous current user is released, and the process
 * keeps the reference taken by the lookup.
 */
sysret_t syscall_set_user_id(uid_t uid)
{
    process_t *proc = get_current_process();
    user_t *target = reference_user_by_id(uid);

    if (!target) return ERR_NOTFOUND;

    if (get_previous_mode() != USER_MODE || check_privileges(PRIVILEGE_CHANGE_UID))
    {
        if (proc->current_user) dereference(&proc->current_user->header);
        proc->current_user = target;
        return ERR_SUCCESS;
    }

    dereference(&target->header);
    return ERR_FORBIDDEN;
}

/*
 * Restores the calling process's current user to its original user.
 *
 * Returns ERR_INVALID when the process is already running as its original
 * user. Otherwise the current-user reference is released and a new reference
 * on the original user is taken for current_user.
 */
sysret_t syscall_revert_user()
{
    process_t *proc = get_current_process();

    if (proc->original_user == proc->current_user) return ERR_INVALID;

    if (proc->current_user) dereference(&proc->current_user->header);
    reference(&proc->original_user->header);
    proc->current_user = proc->original_user;

    return ERR_SUCCESS;
}

/*
 * Creates a new user with the given UID, name, password hash and privileges.
 *
 * User-mode callers need PRIVILEGE_MANAGE_USERS plus every privilege they try
 * to grant. Their password hash (64 dwords) and name are copied into kernel
 * memory before use, so add_user() never dereferences a user-mode pointer.
 *
 * Returns ERR_FORBIDDEN, ERR_BADPTR, or the result of add_user().
 */
sysret_t syscall_create_user(uid_t uid, const char *name, dword_t *password_hash, qword_t privileges)
{
    dword_t safe_password_hash[64];
    const char *safe_name = name;
    char *copied_name = NULL;
    dword_t ret;

    if (get_previous_mode() == USER_MODE)
    {
        if (!check_privileges(privileges | PRIVILEGE_MANAGE_USERS)) return ERR_FORBIDDEN;
        if (!check_usermode(password_hash, sizeof(safe_password_hash))) return ERR_BADPTR;

        EH_TRY
        {
            memcpy(safe_password_hash, password_hash, sizeof(safe_password_hash));
            password_hash = &safe_password_hash[0];
        }
        EH_CATCH
        {
            EH_ESCAPE(return ERR_BADPTR);
        }
        EH_DONE;

        /*
         * Copy the name after the hash so the early return above cannot leak
         * it. The name reaches reference_by_name() and init_object() inside
         * add_user(), so it must not be a raw user-mode pointer.
         */
        copied_name = copy_user_string(name);
        if (copied_name == NULL) return ERR_BADPTR;
        safe_name = copied_name;
    }

    ret = add_user(uid, safe_name, password_hash, privileges);

    /*
     * The kernel copy can be released now. init_object() evidently keeps its
     * own copy: add_user() frees header.name on failure, and user_init()
     * passes a string literal.
     */
    if (copied_name != NULL) free(copied_name);

    return ret;
}

/*
 * Deletes the user with the given UID.
 *
 * User-mode callers need PRIVILEGE_MANAGE_USERS. Returns ERR_BUSY if any
 * process's current or original user has this UID, and ERR_NOTFOUND if no
 * such user exists.
 */
sysret_t syscall_delete_user(uid_t uid)
{
    if (get_previous_mode() == USER_MODE && !check_privileges(PRIVILEGE_MANAGE_USERS)) return ERR_FORBIDDEN;

    /* Refuse to delete a user that some process is still running as. */
    process_t *proc = NULL;
    dword_t status = enum_objects_by_type(OBJECT_PROCESS, (object_t**)&proc);
    ASSERT(status == ERR_SUCCESS || status == ERR_NOMORE);

    for (; status == ERR_SUCCESS; status = enum_objects_by_type(OBJECT_PROCESS, (object_t**)&proc))
    {
        bool_t in_use = proc->current_user->uid == uid || proc->original_user->uid == uid;

        if (in_use)
        {
            /* Release the reference the enumerator holds on this process. */
            dereference(&proc->header);
            return ERR_BUSY;
        }
    }

    ASSERT(status == ERR_NOMORE);

    user_t *victim = reference_user_by_id(uid);
    if (!victim) return ERR_NOTFOUND;

    /*
     * The first dereference drops the lookup reference. The second drops
     * what is presumably the reference held since add_user()/create_object().
     */
    dereference(&victim->header);
    dereference(&victim->header);

    return ERR_SUCCESS;
}

/*
 * Authenticates as user `uid` with `password` and, on success, makes that
 * user the calling process's current user.
 *
 * The password is salted as "%<name>%<password>%" and hashed with SHA-256.
 * The resulting 8 dwords are compared with the start of the stored password
 * hash. Logging on as UID 0 is always refused.
 *
 * Failed attempts are stamped on the calling user. Further attempts within
 * LOGIN_ATTEMPT_TIMEOUT milliseconds of the last failure return ERR_BUSY.
 *
 * Returns ERR_NOTFOUND, ERR_BUSY, ERR_BADPTR, ERR_NOMEMORY, ERR_INVALID or
 * ERR_SUCCESS.
 */
sysret_t syscall_logon_user(uid_t uid, const char *password)
{
    dword_t ret;
    process_t *proc = get_current_process();
    user_t *user = reference_user_by_id(uid);
    user_t *current_user = proc->current_user;
    char *safe_password = NULL;
    char *salted = NULL;
    dword_t hash_sum[8]; /* sha256_compute() produces exactly 8 dwords */

    if (user == NULL) return ERR_NOTFOUND;

    if ((timer_get_milliseconds() - current_user->last_login_attempt) < LOGIN_ATTEMPT_TIMEOUT)
    {
        /* Go through cleanup so the lookup reference on `user` is released. */
        ret = ERR_BUSY;
        goto cleanup;
    }

    if (get_previous_mode() == USER_MODE)
    {
        safe_password = copy_user_string(password);
        if (safe_password == NULL)
        {
            ret = ERR_BADPTR;
            goto cleanup;
        }
    }
    else
    {
        safe_password = (char*)password;
    }

    if (uid == 0)
    {
        current_user->last_login_attempt = timer_get_milliseconds();
        ret = ERR_INVALID;
        goto cleanup;
    }

    /*
     * Heap-allocate instead of alloca: the password length is caller-controlled
     * and must not size a kernel stack allocation. Only the kernel copy of the
     * password is used, never the raw user pointer, which could change between
     * strlen() and strcat(). +4 covers three '%' separators and the NUL.
     */
    salted = (char*)malloc(strlen(user->name) + strlen(safe_password) + 4);
    if (salted == NULL)
    {
        ret = ERR_NOMEMORY;
        goto cleanup;
    }

    strcpy(salted, "%");
    strcat(salted, user->name);
    strcat(salted, "%");
    strcat(salted, safe_password);
    strcat(salted, "%");

    sha256_compute((byte_t*)salted, strlen(salted), hash_sum);

    if (memcmp(hash_sum, user->password_hash, sizeof(hash_sum)) == 0)
    {
        /* The process takes its own reference; release the one on the user it is leaving. */
        reference(&user->header);
        if (current_user) dereference(&current_user->header);
        proc->current_user = user;
        ret = ERR_SUCCESS;
    }
    else
    {
        /* Stamp the same user the throttle check above reads. */
        current_user->last_login_attempt = timer_get_milliseconds();
        ret = ERR_INVALID;
    }

cleanup:
    if (salted != NULL) free(salted);
    if (get_previous_mode() == USER_MODE && safe_password != NULL) free(safe_password);
    dereference(&user->header);
    return ret;
}

/*
 * Copies information about the user with the given UID into `buffer` (`size` bytes).
 *
 * USER_NAME_INFO writes the NUL-terminated user name. USER_PRIVILEGE_INFO
 * writes the privilege mask as a qword_t.
 *
 * For user-mode callers the result is built in a zeroed kernel bounce buffer
 * and then copied out in full.
 *
 * Returns ERR_NOTFOUND, ERR_BADPTR, ERR_NOMEMORY, ERR_SMALLBUF, ERR_INVALID or
 * ERR_SUCCESS.
 */
sysret_t syscall_query_user(uid_t uid, user_info_t info_type, void *buffer, dword_t size)
{
    dword_t ret = ERR_SUCCESS;
    void *safe_buffer;

    user_t *user = reference_user_by_id(uid);
    if (user == NULL) return ERR_NOTFOUND;

    if (get_previous_mode() == USER_MODE)
    {
        if (!check_usermode(buffer, size))
        {
            ret = ERR_BADPTR;
            goto cleanup;
        }

        safe_buffer = malloc(size);

        if (safe_buffer == NULL)
        {
            ret = ERR_NOMEMORY;
            goto cleanup;
        }

        memset(safe_buffer, 0, size);
    }
    else
    {
        safe_buffer = buffer;
    }

    switch (info_type)
    {
    case USER_NAME_INFO:
        /* Compare against the actual name length plus its NUL, not sizeof(size_t). */
        if (size >= strlen(user->name) + 1) strcpy(safe_buffer, user->name);
        else ret = ERR_SMALLBUF;

        break;

    case USER_PRIVILEGE_INFO:
        if (size >= sizeof(qword_t)) *((qword_t*)safe_buffer) = user->privileges;
        else ret = ERR_SMALLBUF;

        break;

    default:
        ret = ERR_INVALID;
    }

    if (get_previous_mode() == USER_MODE)
    {
        EH_TRY memcpy(buffer, safe_buffer, size);
        EH_CATCH ret = ERR_BADPTR;
        EH_DONE;

        /* The bounce buffer is only allocated on the user-mode path. */
        free(safe_buffer);
    }

cleanup:
    dereference(&user->header);
    return ret;
}

/*
 * Creates the root user (UID 0, all-zero password hash, ALL_PRIVILEGES) and
 * makes it both the original and the current user of kernel_process.
 * Crashes the kernel if root cannot be created.
 */
void user_init()
{
    dword_t blank_password_hash[64] = { 0 };
    user_t *root;

    if (add_user(0, "root", blank_password_hash, ALL_PRIVILEGES) != ERR_SUCCESS)
    {
        KERNEL_CRASH("Failed to create root user!");
    }

    /* The lookup supplies one reference; take a second so each of the two fields below owns one. */
    root = reference_user_by_id(0);
    reference(&root->header);

    kernel_process->current_user = root;
    kernel_process->original_user = root;
}
